import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from LovaszSoftmax.pytorch.lovasz_losses import lovasz_hinge
except ImportError:
    pass

__all__ = ['BCEDiceLoss', 'LovaszHingeLoss']


class BCEDiceLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input, target):
        bce = F.binary_cross_entropy_with_logits(input, target)
        smooth = 1e-5    # 平滑因子，防止分母为0
        input = torch.sigmoid(input)    #  将输入的对数几率转化为概率值。
        num = target.size(0)
        input = input.view(num, -1)    # 将input转化为2维张量。
        target = target.view(num, -1)
        intersection = (input * target)   #input是概率，target是0/1，因此计算结果是预测的每个样本的正样本像素的概率。
        dice = (2. * intersection.sum(1) + smooth) / (input.sum(1) + target.sum(1) + smooth) # 预测结果和真实标签之间的相似度
        dice = 1 - dice.sum() / num   # 计算了平均Dice系数，并将其减去1，以得到Dice损失
        return 0.5 * bce + dice


class LovaszHingeLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input, target):
        input = input.squeeze(1)
        target = target.squeeze(1)
        loss = lovasz_hinge(input, target, per_image=True)

        return loss
